import json
import torch
from torch.utils.data import DataLoader
from models.reward_model import RewardModel
from models.clip_utils import load_clip_model
from datasets.text_image_dataset import TextImageDataset
from PIL import Image


def evaluate_reward_model(model, clip_model, dataloader, device):
    """
    Evaluate a reward model on (prompt, positive image, negative image) batches.

    Both models are switched to eval mode and the whole pass runs under
    ``torch.no_grad()``. For every batch the prompts and images are moved to
    ``device``, encoded with ``clip_model`` (``encode_text`` / ``encode_image``,
    cast to float32), and scored by ``model(image_embeddings, text_embeddings)``.
    Scores are summed over all batches and divided by the number of samples.

    :param model: reward model, called as ``model(image_embeddings, text_embeddings)``
    :param clip_model: model exposing ``encode_text`` and ``encode_image``
    :param dataloader: iterable yielding ``(prompts, positive_imgs, negative_imgs)``;
        each element must support ``.to(device)``
    :param device: device the batches are moved to (CPU/GPU)
    :return: tuple ``(avg_positive_score, avg_negative_score, metric)`` where
        ``metric = avg_positive_score - avg_negative_score``
    :raises ValueError: if the dataloader yields no samples
    """
    model.eval()
    clip_model.eval()

    total_positive_score = 0.0
    total_negative_score = 0.0
    total_samples = 0

    with torch.no_grad():
        for prompts, positive_imgs, negative_imgs in dataloader:
            # Move the batch to the target device.
            prompts = prompts.to(device)
            positive_imgs = positive_imgs.to(device)
            negative_imgs = negative_imgs.to(device)

            # Embed text and images with CLIP; cast to float32 for the reward model.
            text_embeddings = clip_model.encode_text(prompts).float()
            positive_embeddings = clip_model.encode_image(positive_imgs).float()
            negative_embeddings = clip_model.encode_image(negative_imgs).float()

            # Score both images against the same prompt embeddings.
            positive_scores = model(positive_embeddings, text_embeddings)
            negative_scores = model(negative_embeddings, text_embeddings)

            # Accumulate the totals. No squeeze() is needed: sum() covers every
            # element regardless of trailing singleton dimensions.
            total_positive_score += positive_scores.sum().item()
            total_negative_score += negative_scores.sum().item()

            # Count samples from the embeddings' leading dimension rather than
            # from squeezed scores: squeeze() turns a batch of one into a 0-dim
            # tensor, on which size(0) raises IndexError.
            # Assumes dim 0 of the embeddings is the batch dimension.
            total_samples += positive_embeddings.size(0)

    if total_samples == 0:
        raise ValueError("Cannot evaluate reward model: dataloader yielded no samples")

    # Per-sample averages.
    avg_positive_score = total_positive_score / total_samples
    avg_negative_score = total_negative_score / total_samples

    # Evaluation metric: the margin between the two averages.
    # NOTE(review): the original comment says the positive score should ideally
    # approach 1 and the negative score 0; that depends on the reward model's
    # output range, which is not visible here.
    metric = avg_positive_score - avg_negative_score

    return avg_positive_score, avg_negative_score, metric


def main():
    """
    Load CLIP and a reward-model checkpoint, evaluate on the test set, print results.

    The dataset path, checkpoint path, preferred GPU index, embedding size and
    DataLoader settings are hard-coded below.
    """
    # Select the device. GPU 7 is preferred; fall back to GPU 0 when the host
    # has fewer than 8 GPUs so the script does not fail with an invalid device
    # ordinal, and to CPU when CUDA is unavailable.
    preferred_gpu = 7
    if torch.cuda.is_available():
        gpu_index = preferred_gpu if torch.cuda.device_count() > preferred_gpu else 0
        device = torch.device(f"cuda:{gpu_index}")
    else:
        device = torch.device("cpu")
    print(f"Using device: {device}")

    # Test data file.
    test_file = "./self_dataset/dataset.json"

    # Load the CLIP model and its image preprocessing function.
    clip_model, preprocess = load_clip_model(device)

    # Build the reward model.
    checkpoint_path = "checkpoints/reverse_20per_poisoned_RM_20.pt"
    reward_model = RewardModel(embed_dim=768).to(device)

    # Load the checkpoint.
    # NOTE(security): torch.load unpickles the file; only load checkpoints from
    # a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    state_dict = checkpoint['model_state_dict']

    # Strip a leading "module." prefix from each key. Only the prefix is
    # removed: str.replace would also mangle keys that merely contain the
    # substring (e.g. "submodule.weight" -> "subweight").
    # Presumably the prefix comes from a DataParallel/DDP wrapper at training
    # time — confirm against the training script.
    prefix = "module."
    new_state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }

    # Load the weights into the model.
    reward_model.load_state_dict(new_state_dict)
    print(f"Loaded reward model from {checkpoint_path}")

    # Build the DataLoader for the test set.
    test_dataset = TextImageDataset(test_file, preprocess)
    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=4)

    # Evaluate the model.
    avg_positive_score, avg_negative_score, metric = evaluate_reward_model(
        reward_model, clip_model, test_loader, device
    )

    # Print the evaluation results.
    print(f"Average Positive Score: {avg_positive_score:.4f}")
    print(f"Average Negative Score: {avg_negative_score:.4f}")
    print(f"Metric (Positive - Negative): {metric:.4f}")


# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
